import argparse
import psutil
from sac import *
from noda.noda_sac import *
from utils.run_utils import setup_logger_kwargs
import copy
import sys
import os
import pdb


def noda_main(args=None, redirect=True):
    if args is None:
        parser = argparse.ArgumentParser()
        parser.add_argument('--env', type=str, default='HalfCheetah-v3')
        parser.add_argument('--hid', type=int, default=256)
        parser.add_argument('--l', type=int, default=2)
        parser.add_argument('--gamma', type=float, default=0.99)
        parser.add_argument('--seed', '-s', type=int, default=0)
        parser.add_argument('--epochs', type=int, default=50)
        parser.add_argument('--exp-name', type=str, default='noda')
        parser.add_argument('--lat-noda', type=int, default=30)
        parser.add_argument('--hid-noda-ae', type=int, default=256)
        parser.add_argument('--hid-noda-ode', type=int, default=64)
        parser.add_argument('--steps-per-epoch', type=int, default=4000)
        parser.add_argument('--model-step', type=int, default=1)
        parser.add_argument('--model-data-ratio', type=float, default=1)
        parser.add_argument('--update-action-turns', type=int, default=3)
        parser.add_argument('--update-model-interval', type=int, default=1)
        parser.add_argument('--explore-lr', type=float, default=0.01)
        parser.add_argument('--update-every', type=int, default=50)
        parser.add_argument('--use-ode', type=int, default=1)
        parser.add_argument('--noise', type=float, default=0.0)
        parser.add_argument('--device', type=str, default=None)
        args = parser.parse_args()
        if args.device is None:
            args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if args.use_ode:
        save_path = 'results/noda'
    else:
        save_path = 'results/noda_mlp'
    if not os.path.isdir(save_path):
        os.makedirs(save_path)
    if redirect:
        stdout = sys.stdout
        f = open(save_path + '/' + args.exp_name + '_s' + str(args.seed) + '.txt', 'w+')
        sys.stdout = f
    else:
        f = None
        stdout = None
    logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed,
                                        os.path.dirname(os.path.realpath('__file__')) + '/' + save_path)
    torch.set_num_threads(1)
    logger = noda(lambda: gym.make(args.env), actor_critic=core.MLPActorCritic,
                  ac_kwargs=dict(hidden_sizes=[args.hid] * args.l),
                  model=NODANoPartial,
                  model_kwargs=dict(latent_dim=args.lat_noda,
                                    hidden_dim_ode=args.hid_noda_ode,
                                    hidden_dim_ae=args.hid_noda_ae,
                                    update_action_turns=args.update_action_turns,
                                    model_step=args.model_step,
                                    model_data_ratio=args.model_data_ratio,
                                    update_model_interval=args.update_model_interval,
                                    explore_lr=args.explore_lr,
                                    use_ode=args.use_ode),
                  gamma=args.gamma, seed=args.seed, epochs=args.epochs,
                  logger_kwargs=logger_kwargs, device=args.device,
                  steps_per_epoch=args.steps_per_epoch, update_every=args.update_every, noise=args.noise)
    if redirect:
        f.close()
        sys.stdout = stdout
    return logger


if __name__ == '__main__':
    noda_main(redirect=True)
